{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use Expected SARSA to Play Taxi-v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "logging.basicConfig(level=logging.INFO,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] env: <TaxiEnv<Taxi-v3>>\n",
      "00:00:00 [INFO] action_space: Discrete(6)\n",
      "00:00:00 [INFO] observation_space: Discrete(500)\n",
      "00:00:00 [INFO] reward_range: (-inf, inf)\n",
      "00:00:00 [INFO] metadata: {'render.modes': ['human', 'ansi']}\n",
      "00:00:00 [INFO] _max_episode_steps: 200\n",
      "00:00:00 [INFO] _elapsed_steps: None\n",
      "00:00:00 [INFO] id: Taxi-v3\n",
      "00:00:00 [INFO] entry_point: gym.envs.toy_text:TaxiEnv\n",
      "00:00:00 [INFO] reward_threshold: 8\n",
      "00:00:00 [INFO] nondeterministic: False\n",
      "00:00:00 [INFO] max_episode_steps: 200\n",
      "00:00:00 [INFO] _kwargs: {}\n",
      "00:00:00 [INFO] _env_name: Taxi\n"
     ]
    }
   ],
   "source": [
    "env = gym.make('Taxi-v3')\n",
    "env.seed(0)\n",
    "for key in vars(env):\n",
    "    logging.info('%s: %s', key, vars(env)[key])\n",
    "for key in vars(env.spec):\n",
    "    logging.info('%s: %s', key, vars(env.spec)[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExpectedSARSAAgent:\n",
    "    def __init__(self, env):\n",
    "        self.gamma = 0.99\n",
    "        self.learning_rate = 0.2\n",
    "        self.epsilon = 0.01\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "\n",
    "    def reset(self, mode=None):\n",
    "        self.mode = mode\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory = []\n",
    "\n",
    "    def step(self, observation, reward, done):\n",
    "        if self.mode == 'train' and np.random.uniform() < self.epsilon:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        else:\n",
    "            action = self.q[observation].argmax()\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory += [observation, reward, done, action]\n",
    "            if len(self.trajectory) >= 8:\n",
    "                self.learn()\n",
    "        return action\n",
    "\n",
    "    def close(self):\n",
    "        pass\n",
    "\n",
    "    def learn(self):\n",
    "        state, _, _, action, next_state, reward, done, _ = \\\n",
    "                        self.trajectory[-8:]\n",
    "\n",
    "        v = (self.q[next_state].mean() * self.epsilon + \\\n",
    "                self.q[next_state].max() * (1. - self.epsilon))\n",
    "        target = reward + self.gamma * v * (1. - done)\n",
    "        td_error = target - self.q[state, action]\n",
    "        self.q[state, action] += self.learning_rate * td_error\n",
    "\n",
    "\n",
    "agent = ExpectedSARSAAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] ==== train ====\n",
      "00:00:06 [INFO] ==== test ====\n",
      "00:00:06 [INFO] average episode reward = 7.67 ± 2.60\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAiDklEQVR4nO3deXyU5bn/8c+VHUjCHrYEwhJAUESJCBRRNqFiRVuPYhf1Z49Uj1u12rrVSvuytfacHrX9daGbtaetovVUW7VWq9YuKkYrAgqKAhJA9n1PuM8f80yYSWZJMsk8M3m+79crL2buZ7ueGeaae677nmfMOYeIiARLjt8BiIhI+in5i4gEkJK/iEgAKfmLiASQkr+ISADl+R1Ac/Xq1ctVVlb6HYaISFZ5/fXXtzrnejduz5rkX1lZSU1Njd9hiIhkFTNbG6tdZR8RkQBS8hcRCSAlfxGRAFLyFxEJICV/EZEAUvIXEQkgJX8RkQBS8m+BfYfqqN2xv+H+7oNH2LDzAADvb9lLXf3RhmVPLd3Ijn2Ho7Zfs3Ufh+rq2XXgCGu37WPFR7t5pGYdyS6rvSfiOI29u2kPi1dvb7hfu2M/ew4eAWDttn08tXQj67bvj9rm+RWbeGHFZgBWfhTaftXmvezaf4TaHfs5eKSeNVv3JXs4+PWra3ln424g9Ng0Pk5jO/YdZvPugw33123fz95DddQfdSx6bV3U49dYSy89vn7ngYbHIdG2rb2k+bL1u3h97Q4AXv1gG39YsoGjR0P7+nDbfvYfrmtynJUf7Ylqqz967NhH6o+yavNeADbvPsji1dv55/tbWxXb6q37OHC4HoC6+qO8uW5nQ2wt0fixeW/THvYdqmPZ+l1NzgXgsTdq2Xuorkl7LAeP1LPa+z+2t9HrqvGx122Pfjx3HTjC+p0H+MGLq9i4q+nronHc727aE/f8nXNs3HWAnfsP8/BrH1JXf5S6+qMNtz/adZCnl25k295DCc9nae0u3vhwR+KTjsM5x95DdSyqWUf9Uce67fvZuvcQa7clfw2mwvy6nr+ZzQbuA3KBnzrn7k60fnV1tWvvL3m9uHIz33t+FYu+MJHcHGuy/Kz7/sbbG3ez4huzGfnVPzW05xiE/289/6XTeendLdz5h7cblo8p78pbtbviHndk3xJWfLSHcYO68+a6nRTk5rBg7mi+/OhbUetNG1nG8ys206e0kLPH9Odnf18dd5+VPTuzZlv0C+qe88c02WcyvUsK2bIn8X/88u6dqN1x7EXYrXM+xYV57DpwhLPH9Oe3iz9kSK8ufOC92GeO6sOzb29KeuweXQrYHvEG2rkgl/1eUgOYc0I/nly6sUXnE0+n/FwmDOnBCyu3RLVf9rHB/Pwfocf5uulV3PeX9xg/uEfUG24ifUuL+CjiDc9Pl582mJ/8bTXl3Tux68AR9hxsmqinjyxj+/7D/OvDnUn3N6RXFzoX5rJs/W4K83I4VBf/zTvswuoKHq5Zl3Cd/l2L2LAr9mNWPag7NWubJtlLJ1Xyj1Vbec97A02kIC/U5z3cjHgbb/eJMf3pXVLIs29/RElRPm+u2xm1zsxRfTh9eG9u//2yqPaLxleQl5PDr16J+X2rpN6680xKi/Jbta2Zve6cq27S7kfyN7Nc4F1gJlALvAZc5Jx7O9426Uj+Y7/+Z3buP8JNs0bQuSCX//exwVHLK29+EoCvzB7Jt/+0ol1jEREJW/2tszBr2iFtjnjJ36/LO4wHVjnnPgAws4eAuUDc5N/eVny0m537Q2WC7zyzEqAh+T+/YhPb9h7rgdYfbVmPQUQkFa1N/In4VfMfAER+9qv12qKY2XwzqzGzmi1btjRe3KZm3/u3uMsue6CGmyLKJbk5Girx27IFs3jxxjOi2rp1jv5Y/PW5owE4d2z/qPaqsmIqe3aOavvNv58adb9/16JmxTFxSE9+cekpfPeCE7l9znF8ZfZInv/S6UwbWRZz/aevO41xg7pz9ydPiLn8vJOOvQz+fP0Urpo6lAlDejS0XTqpMmr9kwd247MTBnLP+WNY8rUz+f6nT4q53+pB3Rtuf+LE/lH3AW48czhVZcVxz3PFN2bz75NDnaFrpg3j3gvHMuO4Pnzj3OOBUHnr6etOa1j/zk+Marj9zfNO4MHLxrPgnNEx9/2zS6p5eP4Ebj1rJH+8ZjKnVHZn0RcmUnP7jIZ17r8o+rzM4PY5x3HppEpevPEMHr1iIi/dNJWXbpraZP+/ufxUam6fwWu3zYhqLyspjBnPCQO68t0LTuS0ql4Nbf82rpxvf+oEhvc59hjdN28sI/qU8MmTmqQuAG49aySnDg49d+ePK+eVW6bz2H9M4h83T+O+eWN59vopLF8wi6V3nhm13Q8+c3LD7c9OGMhzN5wec/+p8qvnH+ttrEn9yTm3EFgIobJPewfV2E2PLOGbMV6ki1dvS3covvjijCrufe69mMvyc42bZo3gm08dK3/98ZrJLKndyZJ1O1lUU0tBbg6HIwZxF986nbLSooby2dyx/Xn8zQ1N9r3m7jn8/xdWNXwCi6W4MI/iwjxumDmcZ5Z/xPINu1lwzmjWbtvPd599l6evO43j+pVy8cRKAK6dXsW0//orAJdPGcIF1RX88p9r+NoTyzl7TD8mDevFk9dOZuFLH/D4mxv4zIRBfOeZlRTl53DwyLFz+OVl43lj7Q7u+8t73HrWSOZPGRozvvlThvC8N6j+1bNH8Y0/hj7UHtevlN9dOQmAC0+p4IZFS1i/8wBfnzuavJwchpUVc0F1BVV9iulVXMhNs0YC8KO/vk/tjv3cNuc4Vm/dx1/f3cK9F47l3EaJZ9bovlH3jx9QyrL1u3n0ykkNj/v3IhJpuO3qaVVcPa2K97fsZbr3OJ0/rpxHX6/l1ME9KMrP5YYzh5ObY1xx+lC6FOY1HPtzEwY17G/ikJ6cOboPl0yqpF+3TkwdUdZQY58yvDeVvbowoFsnrv7NGxypP8r1M4cz/bg+AJw6pCcAj1wxqWF/10wbxoHD9ZxzYn+u/e2/GtpXf2tO1HlW9urS5DkIjx1MHNKzoef86BUTOf9HL3PXecdz3kkDGHXHM1HrF+Tl8IdrJgPwyZPLGwb2L/HedA/VHeWOx5eHzqeqN3PHhh6D4qI8Hnx5Lf+8eRoL/rCc6cf14YLqCuZPGcrGXQfoW1qEmdHX61QMGBv9vP3xmsnUH3UM71NCp4LchvZPnVzOsARvyqnwq+Y/EbjTOTfLu38LgHPuW/G2acua/7rt++nfrVPUoG74RdDYby4/lU//5NU2OW62uXZ6Fa+v3c4/VjV9s3vzjpl061zAsFufoqQoj3/dcaz38vyKTVz2QA15OUadNxL+5LWTGd2/KwCb9xzEMO58YjlPLt3IYO+FG579sebuYy/sypuf5Nyx/XltzQ7WR8x4ilwn0tGjjnc372Fk39Ko9sN1Rxl++9Oh40TUT9/ZuJthZcXk54YSVP1Rx3ve9m9v2M3IviUMufUpAF65ZTp9uxZxqK6ex95Yz4XVFeTEmBgAoRkcj9TUMmdMP4rycxnq7SNe3G3p4JF6fv+v9Vzgxeecw8x4q3YnR+od4yJ6/as27+WjXQeZHNHL/e3iD+naKZ/Zo/vycM06PnnyAArzcmMdKu027z5ITo7Rqzh2rx3ghZWb+WDLPs46oS9v1e5q8oYY6f0te9m06yAlRfmcUN416fH3H67j1seW8rmJgxg3qEfS9VvrhRWbmf+rGt746kxKWjnQG5ZpNf/XgCozGwysB+YBn07Hgddt389p97zA1VOHceOsEUnXb+mMgI7EgB99dhwn3PnnhrYbZg7n85MH06Uw9F9n2YJZNC5HhvsT9REdi3DiBygrCfV+wp8Kbv74SCYN7Rl1nLDlC2ZRmJfDafe80KyYc3KsSeIHot7oI+unx/UrbbJeePtR/UP/9iktZNPuQzjvw2lhXi4XjR+YMA4z44JTKpoVc1srys9lXkR84fMdU96tybrDyoqb9Cwjzy3ZeaZbWWnyctzUEWVM9V7a/bp2Srju0N7FDO3d/J5154I87p0Xu7TWlqaOLOO9u85q12P4Urx2ztUBVwPPAO8Ai5xzy9Nx7M3etMV/NHMO9aW/eK09w0mLR6+YyF++FF03fOaLU/jK7JFRvcDGzKCkKJ+PDevZ0JabYw2JH0KJpnGvMDztNdmHyh6dCwAoKcojJ86AVpfCPPJyc6LmxJ/YjB5aY+HcH67BtkT/bqEEEmv6b3NVD+rOf/7bia3eXqSt+fZjLs65p4Cn/Ds+LN+wK/SRvAXv/NkiN8caEmZ1ZQ+ORNTe/3jNZEb0LWFE3xI+M2Egv/j7Gv77uXeb7MO8oZnvXXQys+99qeGNM5lwKbFnlwK27TtMvziDp1/9xCjGDuzGxCE9o+rqsYRz/6u3To87UJeImfHcDafHjSWRn1xczd/f29rwiaU1Hr1yUvKVRNIocNNWNkV84WbO/X9vGNzqaC4aH11yyI3oWR8/4FjPubQonwvjlCfCm/ToUsCnxpU3+9jhRD1uUHf+5/On8tKXm87AgNCg7UXjB2JmTUpHjYXfUHJzrNXT3oaVFUd9ammuXsWFTQZWRbJd4JL/f/z6DSDG1KIs9OKNZ3Dt9KqotvBUtC4F0UkuUb6Mt6z1RQ7XsN/JVb0aBlMTiVf2CTvqJf9k64lI8wQu+Xcklb26cMPM4VFt550Uu4femt5ya/NsuNZvLXj7SFZODw/ChqcNikhq9ErKIN85f0xK2zeeudJckXk38stNrS2vhMs+Ldk82bF+fPE4Hp4/geJWlG1EpKngJv+IqShffnSJj4Ecc2JFtyZtD8+fwJq750R9kQbgnzdPa7hd0SM0G+XRKyYeWyFOLi2M1XOOWPeZ66cca45oDyfdLgXJ53uHp0S2pESTrOdfWpTf8CUgEUmdulHAoppav0MAoLs39TFSOOFNruoVdUXA8PRDgBe+dAYOktbW75s3lhNjzPWOVFKUzxemDOHHL30QVba5/LQhFOTm8JlGb0KxNLyvtmHPX0TaVnB7/hmod0khi2+b3uLt8nJzYib+O84exTknHruuzdyxA2J+Db5xbb4hd0c0F+TlcPmUIc0avG1F7heRNFPPP8OkMpc8ioPLJg9Ovh5Na/PhaZWtTd4N27ewN9+ntJCrpg5r5VFFpCWU/KWJsRXdgdVRl2RoiXDZp6VfiH311hnJVxKRNhHY5N8R5vnHMnFoaIzg9BG9m71N4xw9Z0w/xg6cxoBuia+LEk94wFdlH5HMFdzkn2XZv7mJdGxFtxZfOTJcnoms0rQ28QN8/Ph+vPz+tobLEYtI5gls8s80gxr9uEgiBc0YdG2NtuqpF+Xncs/5uoiZSCbTbJ8M0ZxZNAAzjuvDu3d9vE2PrfKMSPAEtufvMqzqH/mjOj/+3Dg+2LKP+VOGpOXY4XKP5tqLBEdgk38mS/TLQ+1JqV8kOAJb9lm2frffIUTxs9fdkguwiUjHENjkLxEayj7+hiEi6aPknyUGdA9NvTxpYLd2O4Y+AYgEh2r+GcIl+eLB6P5dee6GKQzp1fY/Oakev0jwqOfvoy4FuUxswWWKh5WVkJPCj4jHY01uiEhHp+Tvo+Vfn83X544GMmOapf8RiEi6KPmnWUt6+umSCW88IpJeSv5p0rVTPgBnxLngWrKaf3sKp369B4gEhwZ802TJ185k3fb9lHfvRK/iQspKC4HMSria7SMSHEr+aVTRI3Txtk+NK29oy7ari4pIx6CyT4bws+4e62cbRaRjU/JPwZmj+rTZvvys+af6s40ikn2U/FMQLuOkIpN625r1IxIcSv4piJcqB3TrxKdPHdisfWRCzT8DQhCRNFPyT0G8pPnktZMZM6BlP37uZ6+7S0EevYoL+Ma5o32LQUTSS7N92kG2lU9yc4ya22f6HYaIpJF6/u0g/IUugF7FBfz04uqk2/g54CsiwaPk386mjSxjRoJZQVn2IUFEOgglf5+pwy8ifkgp+ZvZd8xshZm9ZWb/a2bdIpbdYmarzGylmc2KaB9nZku9ZfdbthXIW6i5l0zo4A+DiGSYVHv+zwLHO+fGAO8CtwCY2ShgHjAamA38wMxyvW1+CMwHqry/2SnGkNFcMydSquYvIumUUvJ3zv3ZOVfn3X0FCF+0Zi7wkHPukHNuNbAKGG9m/YBS59zLLpTtHgTOTSWG5jp4pJ6tew+l41Atog6/iPihLad6XgY87N0eQOjNIKzWazvi3W7c3u5GfvVPbb7P5nTWk5V91OEXET8k7fmb2XNmtizG39yIdW4D6oBfh5ti7MolaI937PlmVmNmNVu2bEkWalz7D9clXykFw/vE/13d5pZ9VPMXkXRK2vN3zs1ItNzMLgHOBqa7Y4XrWqAiYrVyYIPXXh6jPd6xFwILAaqrq1vdRz7t2y+0dtOEwom9Z5dCYG+7HENEpD2kOttnNvAV4Bzn3P6IRU8A88ys0MwGExrYXeyc2wjsMbMJ3iyfi4HHU4mhObbtO9zeh2gQ75e6ktGAr4ikU6o1/+8DhcCzXtniFefcFc655Wa2CHibUDnoKudcvbfNlcADQCfgae8vKyWq5ze3iqNqj4j4IaXk75wblmDZXcBdMdprgONTOW42aG5HXh1+EfFDYL/he9YJff0OIYoGfEUknQKb/C+eWJl0neULZiW8KFusmTyThvYEWl7OUc1fRNIpsMm/Obm5S2EeM0b1Yc3dc5q938tPGwI0v5yjDr+I+CG4yT9J1h3Vr7Rd9tuYOvwi4ofAJv+2FCvft7RHr5q/iKRTYJN/W+TaMeVd6V1SyPUzh6e8L9X8RSSdAvszjm3Rzy4uzOe12xJ+ATp5HOrwi4gP1PNPZR8JlqkjLyKZLLDJP1PoTUJE/BDg5J+465/qJwMN+IpIJgts8m/vXNvSHr0GfEUknYKb/P0OwKMOv4j4IbjJv52zbnN3rw6/iPghsMk/06jmLyLpFNjkn2mpVjV/EUmn4Cb/JNk/XR1xdfhFxA/BTf4Z1/cXEUmfwCb/TKFqj4j4IbDJX+UWEQmywCb/9jZ1RBnFhXlcOmlwwvX0JiQifgjsVT3bW1lpEcsWzPI7DBGRmALb88+UWnumxCEiwRLY5B9peYweerpnA+lLXiKSToEt+7R1rh1WVsyw3sWt3l5f8hKRdAps8m9rz91wequ2U4dfRPygsg9KwCISPIFN/smqLOl6Q1C1R0T8ENjkLyISZIFN/pE9ez+v86OSk4j4IbDJX0QkyAKb/DOl1p4pcYhIsAQ2+WcafclLRNJJyZ/MqLvrS14ikk6BTf5Jf8krPWFkxBuPiARPYJO/iEiQtUnyN7MbzcyZWa+ItlvMbJWZrTSzWRHt48xsqbfsfvOp2J0pVZZMiUNEgiXl5G9mFcBM4MOItlHAPGA0MBv4gZnleot/CMwHqry/2anGICIiLdMWPf//Br4MRPZh5wIPOecOOedWA6uA8WbWDyh1zr3sQiOcDwLntkEMWUs1fxHxQ0rJ38zOAdY755Y0WjQAWBdxv9ZrG+Ddbtweb//zzazGzGq2bNmSSqgJKQGLSNAkvaSzmT0H9I2x6DbgVuDMWJvFaHMJ2mNyzi0EFgJUV1e3aXU8acLXO4KIdGBJk79zbkasdjM7ARgMLPHGbMuBN8xsPKEefUXE6uXABq+9PEZ72mmgVUSCrNVlH+fcUudcmXOu0jlXSSixn+yc+wh4AphnZoVmNpjQwO5i59xGYI+ZTfBm+VwMPJ76aaTGzwu7iYj4oV1+ycs5t9zMFgFvA3XAVc65em/xlcADQCfgae9PRETSqM2Sv9f7j7x/F3BXjPVqgOPb6rip+MrskezYf9jvMERE0i7Qv+F75RlDAThSf9S3GDT2ICJ+0OUdREQCSMk/Dl3YTUQ6MiV/EZEACmzyj/4NXxGRYAlk8v/chEGM7l/qdxgiIr4JZPL/4oyqpD+bqFq8iHRkgUz+jRN/rDcCTcEUkY4smMnf7wBERHwWyOTfHCr7iEhHFsjkHy+x5yjhi0hABDL5i4gEXSCTf+NLOIfvje7fNf3BiIj4IJDJv/GIb06O8cgVE/nV58fHW4Xy7p3aPy4RkTQJ9FU9I51S2SPh8ljjBL1LCtspGhGR9hXInn9bzeQ5saJb2+xIRCTNApn820JVWbHfIYiItFogk79mdIpI0AUy+bdGe/3Ie1F+LgCDe3Vpl/2LiMQSyAHfZBd1i8XRPhf7Ke/emZ9fWp10wFlEpC0FM/k3Z500Xt9h2sg+aTuWiAio7BPXZycMTLhc1/4RkWwWzJ5/gsT9ty9PZcPOA5w6pGf6AhIRSbNAJv9EKnp0pqJH5ybtTS8Joa6/iGSvQJZ9lLhFJOgCmfxbo71m+4iI+CGQyV+DtSISdIFM/iIiQafk30xNBnz16UFEslggk78St4gEXSCTf2towFdEOpIOn/zr6o82adNUTxEJug6f/F9cucXvEEREMk6HT/71rmm5pjU1f31aEJGOJOXkb2bXmNlKM1tuZvdEtN9iZqu8ZbMi2seZ2VJv2f3WzpfPzImx+7Y4YDqv+iki0tZSuraPmU0F5gJjnHOHzKzMax8FzANGA/2B58xsuHOuHvghMB94BXgKmA08nUocCWNsrx2LiGSxVHv+VwJ3O+cOATjnNnvtc4GHnHOHnHOrgVXAeDPrB5Q65152zjngQeDcFGNIKCfGGWbSj7mIiPgh1eQ/HDjNzF41s7+a2Sle+wBgXcR6tV7bAO924/aYzGy+mdWYWc2WLa0buG2vWr0+UYhINkta9jGz54C+MRbd5m3fHZgAnAIsMrMhxM6NLkF7TM65hcBCgOrq6lZ1vWP12FuTuDXgKyIdSdLk75ybEW+ZmV0JPOaVcBab2VGgF6EefUXEquXABq+9PEZ7u4kxzV9EJPBSLfv8HpgGYGbDgQJgK/AEMM/MCs1sMFAFLHbObQT2mNkEb5bPxcDjKcaQUP3RtpnqGXbttGEpRCMikhlS/SWvnwM/N7NlwGHgEu9TwHIzWwS8DdQBV3kzfSA0SPwA0InQLJ92m+kDcMX/vN6kLZVpmiP6lnr7aPUuRER8l1Lyd84dBj4bZ9ldwF0x2muA41M5rh8020dEOpIO/w3fthZ+E1DPX0SymZJ/M4Vn+8S4WoSISNZR8hcRCSAl/xZSuUdEOgIl/2bSgK+IdCRK/i0UrvnrG78iks2U/JupYcDX5zhERNqCkn8rqfYvItlMyV9EJICU/EVEAkjJv5nmju0PQLdO+T5HIiKSOiX/Zrp+xnCWL5hFqZf8VfIXkWym5N9MOTlGl8JUL4IqIpIZlPxbyOniPiLSASj5t5bmeopIFlPyFxEJICV/EZEAUvJvIVX8RaQjUPJvJVX8RSSbKfm3kCb7iEhHoOTfSprsIyLZTMlfRCSAlPxFRAJIyV9EJICU/FtMI74ikv2U/FtJ470iks2U/FtIUz1FpCNQ8m8l01xPEcliSv4iIgGk5C8iEkBK/iIiAaTk30Ia7xWRjqDDJ//y7p2i7j9x9cfaZL8a7hWRbNbhk//0kWVR98eUd0tpf5rqKSIdQUrJ38zGmtkrZvammdWY2fiIZbeY2SozW2lmsyLax5nZUm/Z/dbOcyZvnXNcu+xXMz1FJJul2vO/B1jgnBsL3OHdx8xGAfOA0cBs4Admlutt80NgPlDl/c1OMYaECvNy6VyQm3zFZureOR+A4X1K2myfIiLplpfi9g4o9W53BTZ4t+cCDznnDgGrzWwVMN7M1gClzrmXAczsQeBc4OkU40ibqj4lLPrCRE6s6Op3KCIirZZq8v8i8IyZ/SehTxGTvPYBwCsR69V6bUe8243bYzKz+YQ+JTBw4MBWB9nWdfrxg3u07Q5FRNIsafI3s+eAvjEW3QZMB653zv3OzC4AfgbMIPZkGJegPSbn3EJgIUB1dbWGWkVE2kjS5O+cmxFvmVe2uc67+wjwU+92LVARsWo5oZJQrXe7cbuIiKRRqgO+G4DTvdvTgPe8208A88ys0MwGExrYXeyc2wjsMbMJ3iyfi4HHU4xBRERaKNWa/+XAfWaWBxzEq88755ab2SLgbaAOuMo5V+9tcyXwANCJ0EBv1gz2ioh0FCklf+fc34FxcZbdBdwVo70GOD6V47aU00UZRESidPhv+IK+lSsi0lggkr+IiEQLRPLXpRhERKIFIvmr7CMiEi0Yyd/vAEREMkwgkr+IiERT8hcRCSAlfxGRAApG8lfRX0QkSjCSv4iIRFHyFxEJICV/EZEACkTyv+f8MQCM0O/uiogAAUn+Q3sXA5Cfp+s8iIhAQJK/iIhEC0Tyz/HOsjAv199AREQyRKq/5JUVRvUr5brpVcwbX5F8ZRGRAAhE8jczrp853O8wREQyRiDKPiIiEk3JX0QkgJT8RUQCSMlfRCSAlPxFRAJIyV9EJICU/EVEAkjJX0QkgMy57PiZKzPbAqxt5ea9gK1tGE66KX5/ZXv8kP3noPhbb5BzrnfjxqxJ/qkwsxrnXLXfcbSW4vdXtscP2X8Oir/tqewjIhJASv4iIgEUlOS/0O8AUqT4/ZXt8UP2n4Pib2OBqPmLiEi0oPT8RUQkgpK/iEgAdejkb2azzWylma0ys5v9jiceM1tjZkvN7E0zq/HaepjZs2b2nvdv94j1b/HOaaWZzfIp5p+b2WYzWxbR1uKYzWycd+6rzOx+MzMf47/TzNZ7z8ObZnZWBsdfYWYvmNk7ZrbczK7z2rPiOUgQf1Y8B2ZWZGaLzWyJF/8Crz0rHn8AnHMd8g/IBd4HhgAFwBJglN9xxYl1DdCrUds9wM3e7ZuBb3u3R3nnUggM9s4x14eYpwAnA8tSiRlYDEwEDHga+LiP8d8J3Bhj3UyMvx9wsne7BHjXizMrnoME8WfFc+Adq9i7nQ+8CkzIlsffOdehe/7jgVXOuQ+cc4eBh4C5PsfUEnOBX3q3fwmcG9H+kHPukHNuNbCK0LmmlXPuJWB7o+YWxWxm/YBS59zLLvQqeDBim3YVJ/54MjH+jc65N7zbe4B3gAFkyXOQIP54Mi1+55zb693N9/4cWfL4Q8cu+wwA1kXcryXxfy4/OeDPZva6mc332vo45zZC6IUClHntmXxeLY15gHe7cbufrjazt7yyUPgje0bHb2aVwEmEep9Z9xw0ih+y5Dkws1wzexPYDDzrnMuqx78jJ/9YdbNMndf6MefcycDHgavMbEqCdbPpvMLixZxp5/JDYCgwFtgI/JfXnrHxm1kx8Dvgi8653YlWjdHm+znEiD9rngPnXL1zbixQTqgXf3yC1TMu/o6c/GuBioj75cAGn2JJyDm3wft3M/C/hMo4m7yPhHj/bvZWz+TzamnMtd7txu2+cM5t8l7QR4GfcKyclpHxm1k+ocT5a+fcY15z1jwHseLPtucAwDm3E3gRmE0WPf4dOfm/BlSZ2WAzKwDmAU/4HFMTZtbFzErCt4EzgWWEYr3EW+0S4HHv9hPAPDMrNLPBQBWhAaNM0KKYvY/Fe8xsgjfD4eKIbdIu/KL1nEfoeYAMjN873s+Ad5xz341YlBXPQbz4s+U5MLPeZtbNu90JmAGsIEsef6DjzvYJjZ1wFqFZBO8Dt/kdT5wYhxCaBbAEWB6OE+gJ/AV4z/u3R8Q2t3nntJI0zQyIEfdvCX0sP0Ko9/L51sQMVBN6gb8PfB/vW+c+xf8rYCnwFqEXa78Mjn8yofLAW8Cb3t9Z2fIcJIg/K54DYAzwLy/OZcAdXntWPP7OOV3eQUQkiDpy2UdEROJQ8hcRCSAlfxGRAFLyFxEJICV/EZEAUvIXEQkgJX8RkQD6P026hk6YktofAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):\n",
    "    observation, reward, done = env.reset(), 0., False\n",
    "    agent.reset(mode=mode)\n",
    "    episode_reward, elapsed_steps = 0., 0\n",
    "    while True:\n",
    "        action = agent.step(observation, reward, done)\n",
    "        if render:\n",
    "            env.render()\n",
    "        if done:\n",
    "            break\n",
    "        observation, reward, done, _ = env.step(action)\n",
    "        episode_reward += reward\n",
    "        elapsed_steps += 1\n",
    "        if max_episode_steps and elapsed_steps >= max_episode_steps:\n",
    "            break\n",
    "    agent.close()\n",
    "    return episode_reward, elapsed_steps\n",
    "\n",
    "\n",
    "logging.info('==== train ====')\n",
    "episode_rewards = []\n",
    "for episode in itertools.count():\n",
    "    episode_reward, elapsed_steps = play_episode(env.unwrapped, agent,\n",
    "            max_episode_steps=env._max_episode_steps, mode='train')\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('train episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "    if np.mean(episode_rewards[-200:]) > 8:\n",
    "        break\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "\n",
    "logging.info('==== test ====')\n",
    "episode_rewards = []\n",
    "for episode in range(100):\n",
    "    episode_reward, elapsed_steps = play_episode(env, agent)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('test episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        np.mean(episode_rewards), np.std(episode_rewards))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
